{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Google Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/idealo/image-super-resolution/blob/master/notebooks/ISR_Training_Tutorial.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "QJ4sSTzDWAao"
   },
   "source": [
    "# Install ISR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "KCd2ZuS4V6Z0"
   },
   "outputs": [],
   "source": [
    "!pip install ISR"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "R_BXVVxnN1sx"
   },
   "source": [
    "# Train"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "aJCKjYk-MA-p"
   },
   "source": [
    "\n",
    "## Get the training data\n",
    "Get your data to train the model. The div2k dataset linked here is for a scaling factor of 2. Beware of this later when training the model.\n",
    "\n",
    "(for more options on how to get you data on Colab notebooks visit https://colab.research.google.com/notebooks/io.ipynb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "ytGnfdDo77l-"
   },
   "outputs": [],
   "source": [
    "!wget http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_LR_bicubic_X2.zip\n",
    "!wget http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_LR_bicubic_X2.zip\n",
    "!wget http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip\n",
    "!wget http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_HR.zip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "CMUgC2k21lC9"
   },
   "outputs": [],
   "source": [
    "!mkdir div2k\n",
    "!unzip -q DIV2K_valid_LR_bicubic_X2.zip -d div2k\n",
    "!unzip -q DIV2K_train_LR_bicubic_X2.zip -d div2k\n",
    "!unzip -q DIV2K_train_HR.zip -d div2k\n",
    "!unzip -q DIV2K_valid_HR.zip -d div2k"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "LeNFEtJeNKTj"
   },
   "source": [
    "## Create the models\n",
    "Import the models from the ISR package and create\n",
    "\n",
    "- a RRDN super scaling network\n",
    "- a discriminator network for GANs training\n",
    "- a VGG19 feature extractor to train with a perceptual loss function\n",
    "\n",
    "Carefully select\n",
    "- 'x': this is the upscaling factor (2 by default)\n",
    "- 'layers_to_extract': these are the layers from the VGG19 that will be used in the perceptual loss (leave the default if you're not familiar with it)\n",
    "- 'lr_patch_size': this is the size of the patches that will be extracted from the LR images and fed to the ISR network during training time\n",
    "\n",
    "Play around with the other architecture parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34.0
    },
    "colab_type": "code",
    "id": "8e2RSZq4BY_T",
    "outputId": "71efc07e-be9d-4796-bc24-6ce4c5285330"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "from ISR.models import RRDN\n",
    "from ISR.models import Discriminator\n",
    "from ISR.models import Cut_VGG19"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 88.0
    },
    "colab_type": "code",
    "id": "celHq8FjB5vA",
    "outputId": "150bf398-c71f-4c0b-fc88-2918b32daa29"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Colocations handled automatically by placer.\n"
     ]
    }
   ],
   "source": [
    "lr_train_patch_size = 40\n",
    "layers_to_extract = [5, 9]\n",
    "scale = 2\n",
    "hr_train_patch_size = lr_train_patch_size * scale\n",
    "\n",
    "rrdn  = RRDN(arch_params={'C':4, 'D':3, 'G':64, 'G0':64, 'T':10, 'x':scale}, patch_size=lr_train_patch_size)\n",
    "f_ext = Cut_VGG19(patch_size=hr_train_patch_size, layers_to_extract=layers_to_extract)\n",
    "discr = Discriminator(patch_size=hr_train_patch_size, kernel_size=3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "aaXfv12EPzal"
   },
   "source": [
    "## Give the models to the Trainer\n",
    "The Trainer object will combine the networks, manage your training data and keep you up-to-date with the training progress through Tensorboard and the command line.\n",
    "\n",
    "Here we do not use  the pixel-wise MSE but only the perceptual loss by specifying the respective weights in `loss_weights`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "6AV0m-s8OaqI"
   },
   "outputs": [],
   "source": [
    "from ISR.train import Trainer\n",
    "loss_weights = {\n",
    "  'generator': 0.0,\n",
    "  'feature_extractor': 0.0833,\n",
    "  'discriminator': 0.01\n",
    "}\n",
    "losses = {\n",
    "  'generator': 'mae',\n",
    "  'feature_extractor': 'mse',\n",
    "  'discriminator': 'binary_crossentropy'\n",
    "} \n",
    "\n",
    "log_dirs = {'logs': './logs', 'weights': './weights'}\n",
    "\n",
    "learning_rate = {'initial_value': 0.0004, 'decay_factor': 0.5, 'decay_frequency': 30}\n",
    "\n",
    "flatness = {'min': 0.0, 'max': 0.15, 'increase': 0.01, 'increase_frequency': 5}\n",
    "\n",
    "trainer = Trainer(\n",
    "    generator=rrdn,\n",
    "    discriminator=discr,\n",
    "    feature_extractor=f_ext,\n",
    "    lr_train_dir='div2k/DIV2K_train_LR_bicubic/X2/',\n",
    "    hr_train_dir='div2k/DIV2K_train_HR/',\n",
    "    lr_valid_dir='div2k/DIV2K_train_LR_bicubic/X2/',\n",
    "    hr_valid_dir='div2k/DIV2K_train_HR/',\n",
    "    loss_weights=loss_weights,\n",
    "    learning_rate=learning_rate,\n",
    "    flatness=flatness,\n",
    "    dataname='div2k',\n",
    "    log_dirs=log_dirs,\n",
    "    weights_generator=None,\n",
    "    weights_discriminator=None,\n",
    "    n_validation=40,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "5UpepsY77r5M"
   },
   "source": [
    "Choose epoch number, steps and batch size and start training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 377.0
    },
    "colab_type": "code",
    "id": "YnvSnZUa7rA6",
    "outputId": "97aa04ea-685d-411f-f0ba-1e1b0318480b"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "Training details:\n",
      "Generator settings:\n",
      "{'C': 4, 'D': 3, 'G': 64, 'G0': 64, 'T': 10, 'x': 2}\n",
      "Using GAN discriminator.\n",
      "Using high level features loss:\n",
      "feature_extractor layers: [5, 9]\n",
      "Training session name identifier: rrdn-C4-D3-G64-G064-T10-x2_div2k-vgg19-5-9-srgan-large\n",
      "Input dir: div2k/DIV2K_train_LR_bicubic/X2/\n",
      "Patch size: 40\n",
      "Saving weights under: ./weights/rrdn-C4-D3-G64-G064-T10-x2/div2k-vgg19-5-9-srgan-large\n",
      "Saving tensorboard logs under: ./logs/rrdn-C4-D3-G64-G064-T10-x2/div2k-vgg19-5-9-srgan-large\n",
      "Epoch 0/1\n",
      "Current learning rate: 0.00039999998989515007\n",
      "  0%|          | 0/20 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.cast instead.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.6/dist-packages/keras/engine/training.py:490: UserWarning: Discrepancy between trainable weights and collected trainable weights, did you set `model.trainable` without calling `model.compile` after ?\n",
      "  'Discrepancy between trainable weights and collected trainable'\n",
      " 85%|████████▌ | 17/20 [09:11<01:32, 30.69s/it]"
     ]
    }
   ],
   "source": [
    "trainer.train(\n",
    "    epochs=1,\n",
    "    steps_per_epoch=20,\n",
    "    batch_size=4,\n",
    "    monitored_metrics={'val_PSNR_Y': 'max'}\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "accelerator": "TPU",
  "colab": {
   "name": "ISR Traininig tutorial.ipynb",
   "provenance": [],
   "version": "0.3.2"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
